{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "import env\n",
    "from utils import KaggleCameraDataset, progress_iter\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# training data\n",
    "## load images (links)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_data = KaggleCameraDataset('../data/', train=True, lazy=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## generate maximum number of non-overlapping patches from an image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def gen_patches(x, crop_size=64):\n",
    "    \"\"\"Cut an image into the maximum number of non-overlapping square patches.\n",
    "\n",
    "    The image is tiled row by row starting at the top-left corner; rows and\n",
    "    columns that do not fill a whole patch (bottom/right remainder) are dropped.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : (H, W, C) np.ndarray\n",
    "    crop_size : positive int\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    patches : (D, `crop_size`, `crop_size`, C) np.ndarray of uint8\n",
    "        where D = (H // `crop_size`) * (W // `crop_size`); patch `i * n_W + j`\n",
    "        is the tile in row `i`, column `j` (n_W = W // `crop_size`).\n",
    "    \"\"\"\n",
    "    H, W, C = x.shape\n",
    "    # floor division keeps the counts integral on Python 3 as well as Python 2\n",
    "    n_H = H // crop_size\n",
    "    n_W = W // crop_size\n",
    "    patches = np.zeros((n_H * n_W, crop_size, crop_size, C), dtype=np.uint8)\n",
    "    for i in range(n_H):\n",
    "        for j in range(n_W):\n",
    "            patches[i * n_W + j] = x[i * crop_size:(i + 1) * crop_size,\n",
    "                                     j * crop_size:(j + 1) * crop_size, :]\n",
    "    return patches"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for x, _ in train_data:\n",
    "    x = np.asarray(x, dtype=np.uint8)\n",
    "    plt.imshow(x)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patches = gen_patches(x)\n",
    "plt.imshow(Image.fromarray(patches[0]));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(Image.fromarray(patches[100]));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## function to compute patch score $Q(\\mathbf{x})$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def Q(x, alpha=0.7, beta=4., gamma=np.log(0.01)):\n",
    "    \"\"\"Score a patch, averaging a per-channel quality measure over channels.\n",
    "\n",
    "    For each channel with mean `mu` and standard deviation `std` the score is\n",
    "\n",
    "        alpha * beta * mu * (1 - mu) + (1 - alpha) * (1 - exp(gamma * std))\n",
    "\n",
    "    The first term is largest at `mu` = 0.5 and vanishes at `mu` = 0 or 1;\n",
    "    with the default (negative) `gamma` the second term grows with `std`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : (H, W, C) np.ndarray\n",
    "        [0; 1]-normalized image\n",
    "    alpha : float\n",
    "        Weight of the mean term; the deviation term gets `1 - alpha`.\n",
    "    beta : float\n",
    "        Scale of the mean term.\n",
    "    gamma : float\n",
    "        Exponent rate applied to `std`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Q : float\n",
    "    \"\"\"\n",
    "    # per-channel statistics, shape (C,)\n",
    "    mu = np.mean(x, axis=(0, 1))\n",
    "    std = np.std(x, axis=(0, 1))\n",
    "    q_per_channel = (alpha * beta * mu * (1. - mu)\n",
    "                     + (1. - alpha) * (1. - np.exp(gamma * std)))\n",
    "    return np.mean(q_per_channel)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.56467943787766683"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Q(patches[0]/255.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.82710393207960398"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Q(patches[100]/255.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## carefully assemble most informative $K$ patches for each class\n",
    "### for one class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_K_patches_for_class(c, K=12800, buf_size=2000, crop_size=64, threshold_Q=0.7):\n",
    "    \"\"\"Collect the `K` highest-Q patches over all training images of class `c`.\n",
    "\n",
    "    Every image of the class (paths are read from the global `train_data`) is cut\n",
    "    into non-overlapping patches with `gen_patches` and each patch is scored with\n",
    "    `Q`. Patches scoring above `threshold_Q` (at most `buf_size` per image, best\n",
    "    first) are appended to a pool of capacity `K + buf_size`; whenever the next\n",
    "    batch would not fit, the pool is first pruned down to its best `K` entries.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    c : int\n",
    "        Class label, compared against `train_data.y`.\n",
    "    K : positive int\n",
    "        Number of patches to return.\n",
    "    buf_size : positive int\n",
    "        Maximum number of patches taken from a single image.\n",
    "    crop_size : positive int\n",
    "        Side of the square patches.\n",
    "    threshold_Q : float\n",
    "        Patches with Q <= `threshold_Q` are discarded.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    best_patches : (`K`, `crop_size`, `crop_size`, 3) np.ndarray of uint8\n",
    "        Sorted by decreasing Q. If fewer than `K` patches pass the threshold,\n",
    "        a warning is printed and the remaining entries are all-zero.\n",
    "    \"\"\"\n",
    "    pool_patches = np.zeros((K + buf_size, crop_size, crop_size, 3), dtype=np.uint8)\n",
    "    pool_neg_Q = np.zeros(K + buf_size)  # stores -Q so that argsort puts the best first\n",
    "    pos = 0  # number of valid entries currently in the pool\n",
    "\n",
    "    X = [train_data.X[i] for i in range(len(train_data)) if train_data.y[i] == c]\n",
    "\n",
    "    for path in progress_iter(X, verbose=True):\n",
    "        # gen patches and compute scores\n",
    "        img = np.asarray(Image.open(path), dtype=np.uint8)\n",
    "        patches = gen_patches(img, crop_size=crop_size)\n",
    "        Q_vals = np.asarray([Q(p/255.) for p in patches])\n",
    "\n",
    "        # keep this image's best (at most `buf_size`) patches above the threshold\n",
    "        keep = Q_vals > threshold_Q\n",
    "        patches = patches[keep]\n",
    "        neg_Q = -Q_vals[keep]\n",
    "        ind_local = neg_Q.argsort()[:buf_size]\n",
    "        patches = patches[ind_local]\n",
    "        neg_Q = neg_Q[ind_local]\n",
    "        n_current = len(neg_Q)\n",
    "\n",
    "        # no room for this batch: prune the pool to its K best entries first,\n",
    "        # so that no stored candidate is overwritten before it has been ranked\n",
    "        if pos + n_current > K + buf_size:\n",
    "            ind = pool_neg_Q[:pos].argsort()[:K]\n",
    "            pool_patches[:K] = pool_patches[ind]\n",
    "            pool_neg_Q[:K] = pool_neg_Q[ind]\n",
    "            pos = K\n",
    "\n",
    "        pool_patches[pos:pos + n_current] = patches\n",
    "        pool_neg_Q[pos:pos + n_current] = neg_Q\n",
    "        pos += n_current\n",
    "\n",
    "    if pos < K:\n",
    "        print(\"Reduce threshold! There are only {0} patches out of {1}\".format(pos, K))\n",
    "\n",
    "    # final ranking over everything still in the pool (zero-padded if short)\n",
    "    ind = pool_neg_Q[:pos].argsort()[:K]\n",
    "    best_patches = np.zeros((K, crop_size, crop_size, 3), dtype=np.uint8)\n",
    "    best_patches[:len(ind)] = pool_patches[ind]\n",
    "    return best_patches"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### for all classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8ce3a68368b84060883d228029aa2d46",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "42ba38f58f42442bb691925fe1725baf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "567819c73c6b43c193ce418c1765cddf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "99dbe30162c848ad9d248cba21040a5b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "592f06677c75440188ff93e7b5ae9cf2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c49e011a5c7d419bbcfaaa62ab2915db",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4d45f2d1af1d48dc80fa48c0a2c52016",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "655798b40c70443d988057e1bb2a0363",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d087e42c5b5c4c608432ce7199aed6a5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f81cea728e55445ea27d38180527903c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "K = 12800  # patches kept per class\n",
    "n_classes = 10\n",
    "K_total = n_classes * K\n",
    "patches = np.zeros((K_total, 64, 64, 3), dtype=np.uint8)\n",
    "# class c occupies rows [c*K, (c + 1)*K)\n",
    "for c in range(n_classes):\n",
    "    patches[c*K:(c + 1)*K] = get_K_patches_for_class(c, K=K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(128000, 64, 64, 3)"
      ]
     },
     "execution_count": 119,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "patches.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "np.save('../data/X_patches.npy', patches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "y = np.arange(10).repeat(K)\n",
    "np.save('../data/y_patches.npy', y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# test data\n",
    "## load images (links)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "test_data = KaggleCameraDataset('../data/', train=False, lazy=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## generate 16 most informative patches from images and use them for predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e0e12c0e31f6492bb6b6a8c565378b23",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "A Jupyter Widget"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Cell-specific names (`K_test`, `img_patches`) are used so that the training\n",
    "# variables `K` and `patches` are not overwritten by this cell.\n",
    "K_test = 16  # patches kept per test image\n",
    "test_patches = np.zeros((K_test * len(test_data), 64, 64, 3), dtype=np.uint8)\n",
    "\n",
    "pos = 0\n",
    "for x, _ in progress_iter(test_data, verbose=True):\n",
    "    x = np.asarray(x, dtype=np.uint8)\n",
    "    img_patches = gen_patches(x)\n",
    "    # a single patch would otherwise be silently broadcast over all K_test slots\n",
    "    assert len(img_patches) >= K_test, \"image yields fewer than K_test patches\"\n",
    "    Q_vals = np.asarray([Q(p/255.) for p in img_patches])\n",
    "    ind = (-Q_vals).argsort()  # negate Q values so the highest Q comes first\n",
    "    test_patches[pos:pos + K_test] = img_patches[ind[:K_test]]\n",
    "    pos += K_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "np.save('../data/X_test.npy', test_patches)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
